{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 使用 Unsloth 对 DeepSeek-R1-Distill-Qwen-1.5B 模型进行 LoRA 微调\n",
    "\n",
    "本 Notebook 展示了如何使用 `unsloth` 库对 `deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B` 模型进行高效的 QLoRA (Low-Rank Adaptation) 微调。\n",
    "\n",
    "整个流程包括：\n",
    "1.  环境准备与库导入\n",
    "2.  加载预训练模型和分词器 (Tokenizer)。\n",
    "3.  在微调前，对模型进行简单的推理测试。\n",
    "4.  下载和格式化训练数据集\n",
    "5.  使用 `unsloth` 的 `FastLanguageModel` 来为模型添加 LoRA 适配器。\n",
    "6.  配置 `SFTTrainer` 监督微调训练配置。\n",
    "7.  启动训练，并观察 Loss 变化情况\n",
    "8.  保存微调后的模型\n",
    "9.  测试训练后的生成结果"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 1. 环境准备与库导入\n",
    "\n",
    "首先，我们需要安装并导入所有必要的库。`transformers` 用于加载模型和分词器，`unsloth` 用于高效微调，`trl` 提供了 `SFTTrainer`，而 `datasets` 用于处理数据。\n",
    "\n",
    "**注意**: 在运行此 Notebook 之前，请确保已安装所有依赖包：\n",
    "```bash\n",
    "pip install -r requirements.txt\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.\n",
      "🦥 Unsloth Zoo will now patch everything to make training faster!\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "from unsloth import FastLanguageModel\n",
    "from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, GenerationConfig, DataCollatorForSeq2Seq\n",
    "from datasets import Dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 2. 加载预训练模型和分词器 (Tokenizer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "==((====))==  Unsloth 2025.8.5: Fast Qwen2 patching. Transformers: 4.55.2.\n",
      "   \\\\   /|    Tesla T4. Num GPUs = 1. Max memory: 14.568 GB. Platform: Linux.\n",
      "O^O/ \\_/ \\    Torch: 2.7.1+cu126. CUDA: 7.5. CUDA Toolkit: 12.6. Triton: 3.3.1\n",
      "\\        /    Bfloat16 = FALSE. FA [Xformers = 0.0.31.post1. FA2 = False]\n",
      " \"-____-\"     Free license: http://github.com/unslothai/unsloth\n",
      "Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!\n"
     ]
    }
   ],
   "source": [
    "# 定义模型和一些基本参数\n",
    "max_seq_length = 8192\n",
    "dtype = None # None 表示自动选择 (Float16 a T4, V100, BFloat16 a Ampere)\n",
    "load_in_4bit = True # 使用 4bit 量化加载\n",
    "\n",
    "# 这是您的模型标识符，请替换为您正在使用的模型\n",
    "# 例如：\"qwen-1.5b_lora_model\"\n",
    "# model_name = \"qwen-1.5b_lora_model\" \n",
    "# model_name = \"unsloth/DeepSeek-R1-Distill-Qwen-1.5B\" \n",
    "model_name = \"unsloth/DeepSeek-R1-Distill-Qwen-1.5B-unsloth-bnb-4bit\" \n",
    "\n",
    "# 这一步会返回一个经过 Unsloth 优化的模型和一个分词器\n",
    "model, tokenizer = FastLanguageModel.from_pretrained(\n",
    "    model_name = model_name,\n",
    "    max_seq_length = max_seq_length,\n",
    "    dtype = dtype,\n",
    "    load_in_4bit = load_in_4bit,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 3. 微调前推理测试\n",
    "\n",
    "在对模型进行任何修改之前，我们先用它来生成一段文本，看看原始模型的表现如何。这可以作为我们微调效果的基准参考。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 模型推理的 Prompt 模板\n",
    "inference_prompt = \"\"\"以下是一条描述任务的指令，并配有一个提供进一步上下文的输入。\n",
    "请撰写一份恰当的回复，以完成该请求。\n",
    "在回答之前，请仔细思考该问题，并构建一个分步的思考过程，以确保回应的逻辑严谨和内容准确。\n",
    "\n",
    "\n",
    "### Instruction:\n",
    "你是一位医学专家，在临床推理、诊断学和治疗规划方面拥有深厚的专业知识。\n",
    "请回答以下医学问题。\n",
    "\n",
    "### Question:\n",
    "{}\n",
    "\n",
    "### Response:\n",
    "<think>{}\n",
    "\"\"\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "FastLanguageModel.for_inference(model)\n",
    "\n",
    "question = \"男，28岁，程序员，最近一周每天工作到半夜，感觉头晕、脖子疼，有时候还恶心。\"\n",
    "\n",
    "inputs = tokenizer([inference_prompt.format(question, \"\")], return_tensors=\"pt\").to(\"cuda\")\n",
    "attention_mask = inputs.input_ids.ne(tokenizer.pad_token_id).long().to(\"cuda\")\n",
    "\n",
    "outputs = model.generate(\n",
    "    input_ids=inputs.input_ids,\n",
    "    attention_mask=inputs.attention_mask,\n",
    "    max_new_tokens=1200,\n",
    "    use_cache=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "response = tokenizer.batch_decode(outputs, skip_special_tokens=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "<think>\n",
      "好的，我现在要处理这个医学问题。让我先仔细阅读一下问题描述。\n",
      "\n",
      "这位28岁的男性，程序员，最近一周每天工作到半夜，感觉头晕、脖子疼，有时候还恶心。首先，我需要考虑他的症状，头晕、脖子疼、恶心，这些都是常见的症状，但结合他的职业和工作时间，可能需要更深入的分析。\n",
      "\n",
      "首先，头晕可能是由于脑力劳动过度导致的，比如长时间工作可能会影响脑供血不足，导致头晕。脖子疼可能与工作相关，比如长时间使用电脑、手机等，也可能与颈椎问题有关。恶心可能与长期使用药物、咖啡因或酒精有关，或者与长期的压力或焦虑有关。\n",
      "\n",
      "接下来，我需要考虑他的职业和工作环境。程序员通常工作在信息处理、编程等岗位，工作环境可能与日常办公不同，比如长时间使用键盘、电脑，可能带来疲劳。工作时间到半夜，可能意味着他有其他工作，如学习、休息，或者可能有加班的情况。\n",
      "\n",
      "现在，我需要构建一个分步的思考过程：\n",
      "\n",
      "1. **症状分析**：首先，明确症状是头晕、脖子疼、恶心，以及持续的时间。这些症状可能与长期的工作压力或疲劳相关。\n",
      "\n",
      "2. **职业背景**：程序员通常工作在信息处理领域，工作时间较长，可能影响脑供血和神经功能。工作环境可能与日常办公不同，比如可能有加班或学习，导致疲劳。\n",
      "\n",
      "3. ** possible causes**：考虑可能的原因，如长期使用电子设备、长时间工作导致的疲劳，颈椎问题，或与长期压力相关。\n",
      "\n",
      "4. **可能的诊断**：可能需要检查脑供血不足、颈椎病、压力或焦虑等，也可能需要考虑药物治疗，如抗抑郁药或抗焦虑药，如果出现恶心。\n",
      "\n",
      "5. **后续评估**：如果诊断为脑供血不足，可能需要进行脑供血灌注（CBN）或其他相关治疗。如果脖子疼或颈椎问题，可能需要进一步检查和治疗。\n",
      "\n",
      "6. **综合治疗**：综合考虑，可能需要药物治疗和相关康复治疗，以缓解症状。\n",
      "\n",
      "7. **建议**：建议咨询医生，评估具体情况，可能需要进一步的检查和治疗。\n",
      "\n",
      "现在，我需要确保回答逻辑严谨，内容准确，并且符合医学知识的范畴。\n",
      "</think>\n",
      "\n",
      "男，28岁，程序员，最近一周每天工作到半夜，感觉头晕、脖子疼，有时候还恶心。\n",
      "\n",
      "首先，考虑他的症状：头晕、脖子疼、恶心，可能与长期工作压力、疲劳、颈椎问题等有关。由于他工作时间到半夜，可能涉及加班或学习，带来持续的疲劳。\n",
      "\n",
      "可能的诊断包括：\n",
      "1. **脑供血不足**：头晕可能与脑供血不足有关，需要考虑CBN治疗。\n",
      "2. **颈椎问题**：脖子疼可能与颈椎问题相关，可能需要进一步检查。\n",
      "3. **压力或焦虑**：恶心可能与长期压力或焦虑有关，可能需要药物治疗或心理干预。\n",
      "\n",
      "建议建议咨询医生，评估具体情况，可能需要进一步的检查和治疗。\n"
     ]
    }
   ],
   "source": [
    "print(response[0].split(\"### Response:\")[1])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---\n",
    "\n",
    "### 4. 下载和格式化训练数据集\n",
    "\n",
    "\n",
    "医学推理数据集：https://huggingface.co/datasets/FreedomIntelligence/medical-o1-reasoning-SFT/viewer/zh\n",
    "\n",
    "![dataset](images/dataset.png)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 模型训练的 Prompt 模板\n",
    "train_prompt = \"\"\"以下是一条描述任务的指令，并配有一个提供进一步上下文的输入。\n",
    "请撰写一份恰当的回复，以完成该请求。\n",
    "在回答之前，请仔细思考该问题，并构建一个分步的思考过程，以确保回应的逻辑严谨和内容准确。\n",
    "\n",
    "\n",
    "### Instruction:\n",
    "你是一位医学专家，在临床推理、诊断学和治疗规划方面拥有深厚的专业知识。\n",
    "请回答以下医学问题。\n",
    "\n",
    "### Question:\n",
    "{}\n",
    "\n",
    "### Response:\n",
    "<think>\n",
    "{}\n",
    "</think>\n",
    "{}\n",
    "\"\"\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "33c9959f7f974d319eadd04ce53b4748",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Map:   0%|          | 0/20171 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "EOS_TOKEN = tokenizer.eos_token # 添加 EOS Token\n",
    "\n",
    "def formatting_prompts_func(examples):\n",
    "    inputs = examples[\"Question\"]\n",
    "    cots = examples[\"Complex_CoT\"]\n",
    "    outputs = examples[\"Response\"]\n",
    "    texts = []\n",
    "    for input, cot, output in zip(inputs, cots, outputs):\n",
    "        # 将 EOS Token 添加到样本最后\n",
    "        text = train_prompt.format(input, cot, output) + EOS_TOKEN\n",
    "        texts.append(text)\n",
    "    return { \"text\" : texts, }\n",
    "pass\n",
    "\n",
    "from datasets import load_dataset\n",
    "dataset = load_dataset(\"FreedomIntelligence/medical-o1-reasoning-SFT\", \"zh\", split = \"train\")\n",
    "dataset = dataset.map(formatting_prompts_func, batched = True,)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'以下是一条描述任务的指令，并配有一个提供进一步上下文的输入。\\n请撰写一份恰当的回复，以完成该请求。\\n在回答之前，请仔细思考该问题，并构建一个分步的思考过程，以确保回应的逻辑严谨和内容准确。\\n\\n\\n### Instruction:\\n你是一位医学专家，在临床推理、诊断学和治疗规划方面拥有深厚的专业知识。\\n请回答以下医学问题。\\n\\n### Question:\\n根据描述，一个1岁的孩子在夏季头皮出现多处小结节，长期不愈合，且现在疮大如梅，溃破流脓，口不收敛，头皮下有空洞，患处皮肤增厚。这种病症在中医中诊断为什么病？\\n\\n### Response:\\n```\\n<think>\\n这个小孩子在夏天头皮上长了些小结节，一直都没好，后来变成了脓包，流了好多脓。想想夏天那么热，可能和湿热有关。才一岁的小孩，免疫力本来就不强，夏天的湿热没准就侵袭了身体。\\n\\n用中医的角度来看，出现小结节、再加上长期不愈合，这些症状让我想到了头疮。小孩子最容易得这些皮肤病，主要因为湿热在体表郁结。\\n\\n但再看看，头皮下还有空洞，这可能不止是简单的头疮。看起来病情挺严重的，也许是脓肿没治好。这样的情况中医中有时候叫做禿疮或者湿疮，也可能是另一种情况。\\n\\n等一下，头皮上的空洞和皮肤增厚更像是疾病已经深入到头皮下，这是不是说明有可能是流注或瘰疬？这些名字常描述头部或颈部的严重感染，特别是有化脓不愈合，又形成通道或空洞的情况。\\n\\n仔细想想，我怎么感觉这些症状更贴近瘰疬的表现？尤其考虑到孩子的年纪和夏天发生的季节性因素，湿热可能是主因，但可能也有火毒或者痰湿造成的滞留。\\n\\n回到基本的症状描述上看，这种长期不愈合又复杂的状况，如果结合中医更偏重的病名，是不是有可能是涉及更深层次的感染？\\n\\n再考虑一下，这应该不是单纯的瘰疬，得仔细分析头皮增厚并出现空洞这样的严重症状。中医里头，这样的表现可能更符合‘蚀疮’或‘头疽’。这些病名通常描述头部严重感染后的溃烂和组织坏死。\\n\\n看看季节和孩子的体质，夏天又湿又热，外邪很容易侵入头部，对孩子这么弱的免疫系统简直就是挑战。头疽这个病名听起来真是切合，因为它描述的感染严重，溃烂到出现空洞。\\n\\n不过，仔细琢磨后发现，还有个病名似乎更为合适，叫做‘蝼蛄疖’，这病在中医里专指像这种严重感染并伴有深部空洞的情况。它也涵盖了化脓和皮肤增厚这些症状。\\n\\n哦，该不会是夏季湿热，导致湿毒入侵，孩子的体质不能御，其病情发展成这样的感染？综合分析后我觉得‘蝼蛄疖’这个病名真是相当符合。\\n</think>\\n```\\n从中医的角度来看，你所描述的症状符合“蝼蛄疖”的病症。这种病症通常发生在头皮，表现为多处结节，溃破流脓，形成空洞，患处皮肤增厚且长期不愈合。湿热较重的夏季更容易导致这种病症的发展，特别是在免疫力较弱的儿童身上。建议结合中医的清热解毒、祛湿消肿的治疗方法进行处理，并配合专业的医疗建议进行详细诊断和治疗。\\n<｜end▁of▁sentence｜>'"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dataset[0][\"text\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "以下是一条描述任务的指令，并配有一个提供进一步上下文的输入。\n",
       "请撰写一份恰当的回复，以完成该请求。\n",
       "在回答之前，请仔细思考该问题，并构建一个分步的思考过程，以确保回应的逻辑严谨和内容准确。\n",
       "\n",
       "\n",
       "### Instruction:\n",
       "你是一位医学专家，在临床推理、诊断学和治疗规划方面拥有深厚的专业知识。\n",
       "请回答以下医学问题。\n",
       "\n",
       "### Question:\n",
       "根据描述，一个1岁的孩子在夏季头皮出现多处小结节，长期不愈合，且现在疮大如梅，溃破流脓，口不收敛，头皮下有空洞，患处皮肤增厚。这种病症在中医中诊断为什么病？\n",
       "\n",
       "### Response:\n",
       "```\n",
       "<think>\n",
       "这个小孩子在夏天头皮上长了些小结节，一直都没好，后来变成了脓包，流了好多脓。想想夏天那么热，可能和湿热有关。才一岁的小孩，免疫力本来就不强，夏天的湿热没准就侵袭了身体。\n",
       "\n",
       "用中医的角度来看，出现小结节、再加上长期不愈合，这些症状让我想到了头疮。小孩子最容易得这些皮肤病，主要因为湿热在体表郁结。\n",
       "\n",
       "但再看看，头皮下还有空洞，这可能不止是简单的头疮。看起来病情挺严重的，也许是脓肿没治好。这样的情况中医中有时候叫做禿疮或者湿疮，也可能是另一种情况。\n",
       "\n",
       "等一下，头皮上的空洞和皮肤增厚更像是疾病已经深入到头皮下，这是不是说明有可能是流注或瘰疬？这些名字常描述头部或颈部的严重感染，特别是有化脓不愈合，又形成通道或空洞的情况。\n",
       "\n",
       "仔细想想，我怎么感觉这些症状更贴近瘰疬的表现？尤其考虑到孩子的年纪和夏天发生的季节性因素，湿热可能是主因，但可能也有火毒或者痰湿造成的滞留。\n",
       "\n",
       "回到基本的症状描述上看，这种长期不愈合又复杂的状况，如果结合中医更偏重的病名，是不是有可能是涉及更深层次的感染？\n",
       "\n",
       "再考虑一下，这应该不是单纯的瘰疬，得仔细分析头皮增厚并出现空洞这样的严重症状。中医里头，这样的表现可能更符合‘蚀疮’或‘头疽’。这些病名通常描述头部严重感染后的溃烂和组织坏死。\n",
       "\n",
       "看看季节和孩子的体质，夏天又湿又热，外邪很容易侵入头部，对孩子这么弱的免疫系统简直就是挑战。头疽这个病名听起来真是切合，因为它描述的感染严重，溃烂到出现空洞。\n",
       "\n",
       "不过，仔细琢磨后发现，还有个病名似乎更为合适，叫做‘蝼蛄疖’，这病在中医里专指像这种严重感染并伴有深部空洞的情况。它也涵盖了化脓和皮肤增厚这些症状。\n",
       "\n",
       "哦，该不会是夏季湿热，导致湿毒入侵，孩子的体质不能御，其病情发展成这样的感染？综合分析后我觉得‘蝼蛄疖’这个病名真是相当符合。\n",
       "</think>\n",
       "```\n",
       "从中医的角度来看，你所描述的症状符合“蝼蛄疖”的病症。这种病症通常发生在头皮，表现为多处结节，溃破流脓，形成空洞，患处皮肤增厚且长期不愈合。湿热较重的夏季更容易导致这种病症的发展，特别是在免疫力较弱的儿童身上。建议结合中医的清热解毒、祛湿消肿的治疗方法进行处理，并配合专业的医疗建议进行详细诊断和治疗。\n",
       "<｜end▁of▁sentence｜>"
      ],
      "text/plain": [
       "<IPython.core.display.Markdown object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "from IPython.display import display, Markdown\n",
    "\n",
    "display(Markdown(dataset[0][\"text\"]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 5. 使用 Unsloth 添加 LoRA 适配器\n",
    "\n",
    "这是使用 `unsloth` 的核心步骤。我们调用 `FastLanguageModel.get_peft_model`，它会非常高效地为模型注入 LoRA 模块。\n",
    "\n",
    "- `r`: LoRA 的秩 (rank)，是控制模型复杂度和参数量的关键超参数。\n",
    "- `target_modules`: 指定要在哪些线性层（如注意力机制的 q, k, v, o 投影层）上应用 LoRA。\n",
    "- `lora_alpha`: LoRA 的缩放因子，通常设置为 `r` 的两倍或与 `r` 相同。\n",
    "- `use_gradient_checkpointing`: 一种节省显存的技术，对于训练大模型至关重要。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Unsloth 2025.8.5 patched 28 layers with 28 QKV layers, 28 O layers and 28 MLP layers.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "PeftModelForCausalLM(\n",
      "  (base_model): LoraModel(\n",
      "    (model): Qwen2ForCausalLM(\n",
      "      (model): Qwen2Model(\n",
      "        (embed_tokens): Embedding(151936, 1536, padding_idx=151654)\n",
      "        (layers): ModuleList(\n",
      "          (0): Qwen2DecoderLayer(\n",
      "            (self_attn): Qwen2Attention(\n",
      "              (q_proj): lora.Linear(\n",
      "                (base_layer): Linear(in_features=1536, out_features=1536, bias=True)\n",
      "                (lora_dropout): ModuleDict(\n",
      "                  (default): Identity()\n",
      "                )\n",
      "                (lora_A): ModuleDict(\n",
      "                  (default): Linear(in_features=1536, out_features=16, bias=False)\n",
      "                )\n",
      "                (lora_B): ModuleDict(\n",
      "                  (default): Linear(in_features=16, out_features=1536, bias=False)\n",
      "                )\n",
      "                (lora_embedding_A): ParameterDict()\n",
      "                (lora_embedding_B): ParameterDict()\n",
      "                (lora_magnitude_vector): ModuleDict()\n",
      "              )\n",
      "              (k_proj): lora.Linear(\n",
      "                (base_layer): Linear(in_features=1536, out_features=256, bias=True)\n",
      "                (lora_dropout): ModuleDict(\n",
      "                  (default): Identity()\n",
      "                )\n",
      "                (lora_A): ModuleDict(\n",
      "                  (default): Linear(in_features=1536, out_features=16, bias=False)\n",
      "                )\n",
      "                (lora_B): ModuleDict(\n",
      "                  (default): Linear(in_features=16, out_features=256, bias=False)\n",
      "                )\n",
      "                (lora_embedding_A): ParameterDict()\n",
      "                (lora_embedding_B): ParameterDict()\n",
      "                (lora_magnitude_vector): ModuleDict()\n",
      "              )\n",
      "              (v_proj): lora.Linear(\n",
      "                (base_layer): Linear(in_features=1536, out_features=256, bias=True)\n",
      "                (lora_dropout): ModuleDict(\n",
      "                  (default): Identity()\n",
      "                )\n",
      "                (lora_A): ModuleDict(\n",
      "                  (default): Linear(in_features=1536, out_features=16, bias=False)\n",
      "                )\n",
      "                (lora_B): ModuleDict(\n",
      "                  (default): Linear(in_features=16, out_features=256, bias=False)\n",
      "                )\n",
      "                (lora_embedding_A): ParameterDict()\n",
      "                (lora_embedding_B): ParameterDict()\n",
      "                (lora_magnitude_vector): ModuleDict()\n",
      "              )\n",
      "              (o_proj): lora.Linear(\n",
      "                (base_layer): Linear(in_features=1536, out_features=1536, bias=False)\n",
      "                (lora_dropout): ModuleDict(\n",
      "                  (default): Identity()\n",
      "                )\n",
      "                (lora_A): ModuleDict(\n",
      "                  (default): Linear(in_features=1536, out_features=16, bias=False)\n",
      "                )\n",
      "                (lora_B): ModuleDict(\n",
      "                  (default): Linear(in_features=16, out_features=1536, bias=False)\n",
      "                )\n",
      "                (lora_embedding_A): ParameterDict()\n",
      "                (lora_embedding_B): ParameterDict()\n",
      "                (lora_magnitude_vector): ModuleDict()\n",
      "              )\n",
      "              (rotary_emb): LlamaRotaryEmbedding()\n",
      "            )\n",
      "            (mlp): Qwen2MLP(\n",
      "              (gate_proj): lora.Linear4bit(\n",
      "                (base_layer): Linear4bit(in_features=1536, out_features=8960, bias=False)\n",
      "                (lora_dropout): ModuleDict(\n",
      "                  (default): Identity()\n",
      "                )\n",
      "                (lora_A): ModuleDict(\n",
      "                  (default): Linear(in_features=1536, out_features=16, bias=False)\n",
      "                )\n",
      "                (lora_B): ModuleDict(\n",
      "                  (default): Linear(in_features=16, out_features=8960, bias=False)\n",
      "                )\n",
      "                (lora_embedding_A): ParameterDict()\n",
      "                (lora_embedding_B): ParameterDict()\n",
      "                (lora_magnitude_vector): ModuleDict()\n",
      "              )\n",
      "              (up_proj): lora.Linear4bit(\n",
      "                (base_layer): Linear4bit(in_features=1536, out_features=8960, bias=False)\n",
      "                (lora_dropout): ModuleDict(\n",
      "                  (default): Identity()\n",
      "                )\n",
      "                (lora_A): ModuleDict(\n",
      "                  (default): Linear(in_features=1536, out_features=16, bias=False)\n",
      "                )\n",
      "                (lora_B): ModuleDict(\n",
      "                  (default): Linear(in_features=16, out_features=8960, bias=False)\n",
      "                )\n",
      "                (lora_embedding_A): ParameterDict()\n",
      "                (lora_embedding_B): ParameterDict()\n",
      "                (lora_magnitude_vector): ModuleDict()\n",
      "              )\n",
      "              (down_proj): lora.Linear4bit(\n",
      "                (base_layer): Linear4bit(in_features=8960, out_features=1536, bias=False)\n",
      "                (lora_dropout): ModuleDict(\n",
      "                  (default): Identity()\n",
      "                )\n",
      "                (lora_A): ModuleDict(\n",
      "                  (default): Linear(in_features=8960, out_features=16, bias=False)\n",
      "                )\n",
      "                (lora_B): ModuleDict(\n",
      "                  (default): Linear(in_features=16, out_features=1536, bias=False)\n",
      "                )\n",
      "                (lora_embedding_A): ParameterDict()\n",
      "                (lora_embedding_B): ParameterDict()\n",
      "                (lora_magnitude_vector): ModuleDict()\n",
      "              )\n",
      "              (act_fn): SiLU()\n",
      "            )\n",
      "            (input_layernorm): Qwen2RMSNorm((1536,), eps=1e-06)\n",
      "            (post_attention_layernorm): Qwen2RMSNorm((1536,), eps=1e-06)\n",
      "          )\n",
      "          (1-2): 2 x Qwen2DecoderLayer(\n",
      "            (self_attn): Qwen2Attention(\n",
      "              (q_proj): lora.Linear4bit(\n",
      "                (base_layer): Linear4bit(in_features=1536, out_features=1536, bias=True)\n",
      "                (lora_dropout): ModuleDict(\n",
      "                  (default): Identity()\n",
      "                )\n",
      "                (lora_A): ModuleDict(\n",
      "                  (default): Linear(in_features=1536, out_features=16, bias=False)\n",
      "                )\n",
      "                (lora_B): ModuleDict(\n",
      "                  (default): Linear(in_features=16, out_features=1536, bias=False)\n",
      "                )\n",
      "                (lora_embedding_A): ParameterDict()\n",
      "                (lora_embedding_B): ParameterDict()\n",
      "                (lora_magnitude_vector): ModuleDict()\n",
      "              )\n",
      "              (k_proj): lora.Linear4bit(\n",
      "                (base_layer): Linear4bit(in_features=1536, out_features=256, bias=True)\n",
      "                (lora_dropout): ModuleDict(\n",
      "                  (default): Identity()\n",
      "                )\n",
      "                (lora_A): ModuleDict(\n",
      "                  (default): Linear(in_features=1536, out_features=16, bias=False)\n",
      "                )\n",
      "                (lora_B): ModuleDict(\n",
      "                  (default): Linear(in_features=16, out_features=256, bias=False)\n",
      "                )\n",
      "                (lora_embedding_A): ParameterDict()\n",
      "                (lora_embedding_B): ParameterDict()\n",
      "                (lora_magnitude_vector): ModuleDict()\n",
      "              )\n",
      "              (v_proj): lora.Linear4bit(\n",
      "                (base_layer): Linear4bit(in_features=1536, out_features=256, bias=True)\n",
      "                (lora_dropout): ModuleDict(\n",
      "                  (default): Identity()\n",
      "                )\n",
      "                (lora_A): ModuleDict(\n",
      "                  (default): Linear(in_features=1536, out_features=16, bias=False)\n",
      "                )\n",
      "                (lora_B): ModuleDict(\n",
      "                  (default): Linear(in_features=16, out_features=256, bias=False)\n",
      "                )\n",
      "                (lora_embedding_A): ParameterDict()\n",
      "                (lora_embedding_B): ParameterDict()\n",
      "                (lora_magnitude_vector): ModuleDict()\n",
      "              )\n",
      "              (o_proj): lora.Linear4bit(\n",
      "                (base_layer): Linear4bit(in_features=1536, out_features=1536, bias=False)\n",
      "                (lora_dropout): ModuleDict(\n",
      "                  (default): Identity()\n",
      "                )\n",
      "                (lora_A): ModuleDict(\n",
      "                  (default): Linear(in_features=1536, out_features=16, bias=False)\n",
      "                )\n",
      "                (lora_B): ModuleDict(\n",
      "                  (default): Linear(in_features=16, out_features=1536, bias=False)\n",
      "                )\n",
      "                (lora_embedding_A): ParameterDict()\n",
      "                (lora_embedding_B): ParameterDict()\n",
      "                (lora_magnitude_vector): ModuleDict()\n",
      "              )\n",
      "              (rotary_emb): LlamaRotaryEmbedding()\n",
      "            )\n",
      "            (mlp): Qwen2MLP(\n",
      "              (gate_proj): lora.Linear(\n",
      "                (base_layer): Linear(in_features=1536, out_features=8960, bias=False)\n",
      "                (lora_dropout): ModuleDict(\n",
      "                  (default): Identity()\n",
      "                )\n",
      "                (lora_A): ModuleDict(\n",
      "                  (default): Linear(in_features=1536, out_features=16, bias=False)\n",
      "                )\n",
      "                (lora_B): ModuleDict(\n",
      "                  (default): Linear(in_features=16, out_features=8960, bias=False)\n",
      "                )\n",
      "                (lora_embedding_A): ParameterDict()\n",
      "                (lora_embedding_B): ParameterDict()\n",
      "                (lora_magnitude_vector): ModuleDict()\n",
      "              )\n",
      "              (up_proj): lora.Linear(\n",
      "                (base_layer): Linear(in_features=1536, out_features=8960, bias=False)\n",
      "                (lora_dropout): ModuleDict(\n",
      "                  (default): Identity()\n",
      "                )\n",
      "                (lora_A): ModuleDict(\n",
      "                  (default): Linear(in_features=1536, out_features=16, bias=False)\n",
      "                )\n",
      "                (lora_B): ModuleDict(\n",
      "                  (default): Linear(in_features=16, out_features=8960, bias=False)\n",
      "                )\n",
      "                (lora_embedding_A): ParameterDict()\n",
      "                (lora_embedding_B): ParameterDict()\n",
      "                (lora_magnitude_vector): ModuleDict()\n",
      "              )\n",
      "              (down_proj): lora.Linear(\n",
      "                (base_layer): Linear(in_features=8960, out_features=1536, bias=False)\n",
      "                (lora_dropout): ModuleDict(\n",
      "                  (default): Identity()\n",
      "                )\n",
      "                (lora_A): ModuleDict(\n",
      "                  (default): Linear(in_features=8960, out_features=16, bias=False)\n",
      "                )\n",
      "                (lora_B): ModuleDict(\n",
      "                  (default): Linear(in_features=16, out_features=1536, bias=False)\n",
      "                )\n",
      "                (lora_embedding_A): ParameterDict()\n",
      "                (lora_embedding_B): ParameterDict()\n",
      "                (lora_magnitude_vector): ModuleDict()\n",
      "              )\n",
      "              (act_fn): SiLU()\n",
      "            )\n",
      "            (input_layernorm): Qwen2RMSNorm((1536,), eps=1e-06)\n",
      "            (post_attention_layernorm): Qwen2RMSNorm((1536,), eps=1e-06)\n",
      "          )\n",
      "          (3-25): 23 x Qwen2DecoderLayer(\n",
      "            (self_attn): Qwen2Attention(\n",
      "              (q_proj): lora.Linear4bit(\n",
      "                (base_layer): Linear4bit(in_features=1536, out_features=1536, bias=True)\n",
      "                (lora_dropout): ModuleDict(\n",
      "                  (default): Identity()\n",
      "                )\n",
      "                (lora_A): ModuleDict(\n",
      "                  (default): Linear(in_features=1536, out_features=16, bias=False)\n",
      "                )\n",
      "                (lora_B): ModuleDict(\n",
      "                  (default): Linear(in_features=16, out_features=1536, bias=False)\n",
      "                )\n",
      "                (lora_embedding_A): ParameterDict()\n",
      "                (lora_embedding_B): ParameterDict()\n",
      "                (lora_magnitude_vector): ModuleDict()\n",
      "              )\n",
      "              (k_proj): lora.Linear4bit(\n",
      "                (base_layer): Linear4bit(in_features=1536, out_features=256, bias=True)\n",
      "                (lora_dropout): ModuleDict(\n",
      "                  (default): Identity()\n",
      "                )\n",
      "                (lora_A): ModuleDict(\n",
      "                  (default): Linear(in_features=1536, out_features=16, bias=False)\n",
      "                )\n",
      "                (lora_B): ModuleDict(\n",
      "                  (default): Linear(in_features=16, out_features=256, bias=False)\n",
      "                )\n",
      "                (lora_embedding_A): ParameterDict()\n",
      "                (lora_embedding_B): ParameterDict()\n",
      "                (lora_magnitude_vector): ModuleDict()\n",
      "              )\n",
      "              (v_proj): lora.Linear4bit(\n",
      "                (base_layer): Linear4bit(in_features=1536, out_features=256, bias=True)\n",
      "                (lora_dropout): ModuleDict(\n",
      "                  (default): Identity()\n",
      "                )\n",
      "                (lora_A): ModuleDict(\n",
      "                  (default): Linear(in_features=1536, out_features=16, bias=False)\n",
      "                )\n",
      "                (lora_B): ModuleDict(\n",
      "                  (default): Linear(in_features=16, out_features=256, bias=False)\n",
      "                )\n",
      "                (lora_embedding_A): ParameterDict()\n",
      "                (lora_embedding_B): ParameterDict()\n",
      "                (lora_magnitude_vector): ModuleDict()\n",
      "              )\n",
      "              (o_proj): lora.Linear4bit(\n",
      "                (base_layer): Linear4bit(in_features=1536, out_features=1536, bias=False)\n",
      "                (lora_dropout): ModuleDict(\n",
      "                  (default): Identity()\n",
      "                )\n",
      "                (lora_A): ModuleDict(\n",
      "                  (default): Linear(in_features=1536, out_features=16, bias=False)\n",
      "                )\n",
      "                (lora_B): ModuleDict(\n",
      "                  (default): Linear(in_features=16, out_features=1536, bias=False)\n",
      "                )\n",
      "                (lora_embedding_A): ParameterDict()\n",
      "                (lora_embedding_B): ParameterDict()\n",
      "                (lora_magnitude_vector): ModuleDict()\n",
      "              )\n",
      "              (rotary_emb): LlamaRotaryEmbedding()\n",
      "            )\n",
      "            (mlp): Qwen2MLP(\n",
      "              (gate_proj): lora.Linear4bit(\n",
      "                (base_layer): Linear4bit(in_features=1536, out_features=8960, bias=False)\n",
      "                (lora_dropout): ModuleDict(\n",
      "                  (default): Identity()\n",
      "                )\n",
      "                (lora_A): ModuleDict(\n",
      "                  (default): Linear(in_features=1536, out_features=16, bias=False)\n",
      "                )\n",
      "                (lora_B): ModuleDict(\n",
      "                  (default): Linear(in_features=16, out_features=8960, bias=False)\n",
      "                )\n",
      "                (lora_embedding_A): ParameterDict()\n",
      "                (lora_embedding_B): ParameterDict()\n",
      "                (lora_magnitude_vector): ModuleDict()\n",
      "              )\n",
      "              (up_proj): lora.Linear4bit(\n",
      "                (base_layer): Linear4bit(in_features=1536, out_features=8960, bias=False)\n",
      "                (lora_dropout): ModuleDict(\n",
      "                  (default): Identity()\n",
      "                )\n",
      "                (lora_A): ModuleDict(\n",
      "                  (default): Linear(in_features=1536, out_features=16, bias=False)\n",
      "                )\n",
      "                (lora_B): ModuleDict(\n",
      "                  (default): Linear(in_features=16, out_features=8960, bias=False)\n",
      "                )\n",
      "                (lora_embedding_A): ParameterDict()\n",
      "                (lora_embedding_B): ParameterDict()\n",
      "                (lora_magnitude_vector): ModuleDict()\n",
      "              )\n",
      "              (down_proj): lora.Linear4bit(\n",
      "                (base_layer): Linear4bit(in_features=8960, out_features=1536, bias=False)\n",
      "                (lora_dropout): ModuleDict(\n",
      "                  (default): Identity()\n",
      "                )\n",
      "                (lora_A): ModuleDict(\n",
      "                  (default): Linear(in_features=8960, out_features=16, bias=False)\n",
      "                )\n",
      "                (lora_B): ModuleDict(\n",
      "                  (default): Linear(in_features=16, out_features=1536, bias=False)\n",
      "                )\n",
      "                (lora_embedding_A): ParameterDict()\n",
      "                (lora_embedding_B): ParameterDict()\n",
      "                (lora_magnitude_vector): ModuleDict()\n",
      "              )\n",
      "              (act_fn): SiLU()\n",
      "            )\n",
      "            (input_layernorm): Qwen2RMSNorm((1536,), eps=1e-06)\n",
      "            (post_attention_layernorm): Qwen2RMSNorm((1536,), eps=1e-06)\n",
      "          )\n",
      "          (26): Qwen2DecoderLayer(\n",
      "            (self_attn): Qwen2Attention(\n",
      "              (q_proj): lora.Linear4bit(\n",
      "                (base_layer): Linear4bit(in_features=1536, out_features=1536, bias=True)\n",
      "                (lora_dropout): ModuleDict(\n",
      "                  (default): Identity()\n",
      "                )\n",
      "                (lora_A): ModuleDict(\n",
      "                  (default): Linear(in_features=1536, out_features=16, bias=False)\n",
      "                )\n",
      "                (lora_B): ModuleDict(\n",
      "                  (default): Linear(in_features=16, out_features=1536, bias=False)\n",
      "                )\n",
      "                (lora_embedding_A): ParameterDict()\n",
      "                (lora_embedding_B): ParameterDict()\n",
      "                (lora_magnitude_vector): ModuleDict()\n",
      "              )\n",
      "              (k_proj): lora.Linear4bit(\n",
      "                (base_layer): Linear4bit(in_features=1536, out_features=256, bias=True)\n",
      "                (lora_dropout): ModuleDict(\n",
      "                  (default): Identity()\n",
      "                )\n",
      "                (lora_A): ModuleDict(\n",
      "                  (default): Linear(in_features=1536, out_features=16, bias=False)\n",
      "                )\n",
      "                (lora_B): ModuleDict(\n",
      "                  (default): Linear(in_features=16, out_features=256, bias=False)\n",
      "                )\n",
      "                (lora_embedding_A): ParameterDict()\n",
      "                (lora_embedding_B): ParameterDict()\n",
      "                (lora_magnitude_vector): ModuleDict()\n",
      "              )\n",
      "              (v_proj): lora.Linear4bit(\n",
      "                (base_layer): Linear4bit(in_features=1536, out_features=256, bias=True)\n",
      "                (lora_dropout): ModuleDict(\n",
      "                  (default): Identity()\n",
      "                )\n",
      "                (lora_A): ModuleDict(\n",
      "                  (default): Linear(in_features=1536, out_features=16, bias=False)\n",
      "                )\n",
      "                (lora_B): ModuleDict(\n",
      "                  (default): Linear(in_features=16, out_features=256, bias=False)\n",
      "                )\n",
      "                (lora_embedding_A): ParameterDict()\n",
      "                (lora_embedding_B): ParameterDict()\n",
      "                (lora_magnitude_vector): ModuleDict()\n",
      "              )\n",
      "              (o_proj): lora.Linear4bit(\n",
      "                (base_layer): Linear4bit(in_features=1536, out_features=1536, bias=False)\n",
      "                (lora_dropout): ModuleDict(\n",
      "                  (default): Identity()\n",
      "                )\n",
      "                (lora_A): ModuleDict(\n",
      "                  (default): Linear(in_features=1536, out_features=16, bias=False)\n",
      "                )\n",
      "                (lora_B): ModuleDict(\n",
      "                  (default): Linear(in_features=16, out_features=1536, bias=False)\n",
      "                )\n",
      "                (lora_embedding_A): ParameterDict()\n",
      "                (lora_embedding_B): ParameterDict()\n",
      "                (lora_magnitude_vector): ModuleDict()\n",
      "              )\n",
      "              (rotary_emb): LlamaRotaryEmbedding()\n",
      "            )\n",
      "            (mlp): Qwen2MLP(\n",
      "              (gate_proj): lora.Linear(\n",
      "                (base_layer): Linear(in_features=1536, out_features=8960, bias=False)\n",
      "                (lora_dropout): ModuleDict(\n",
      "                  (default): Identity()\n",
      "                )\n",
      "                (lora_A): ModuleDict(\n",
      "                  (default): Linear(in_features=1536, out_features=16, bias=False)\n",
      "                )\n",
      "                (lora_B): ModuleDict(\n",
      "                  (default): Linear(in_features=16, out_features=8960, bias=False)\n",
      "                )\n",
      "                (lora_embedding_A): ParameterDict()\n",
      "                (lora_embedding_B): ParameterDict()\n",
      "                (lora_magnitude_vector): ModuleDict()\n",
      "              )\n",
      "              (up_proj): lora.Linear(\n",
      "                (base_layer): Linear(in_features=1536, out_features=8960, bias=False)\n",
      "                (lora_dropout): ModuleDict(\n",
      "                  (default): Identity()\n",
      "                )\n",
      "                (lora_A): ModuleDict(\n",
      "                  (default): Linear(in_features=1536, out_features=16, bias=False)\n",
      "                )\n",
      "                (lora_B): ModuleDict(\n",
      "                  (default): Linear(in_features=16, out_features=8960, bias=False)\n",
      "                )\n",
      "                (lora_embedding_A): ParameterDict()\n",
      "                (lora_embedding_B): ParameterDict()\n",
      "                (lora_magnitude_vector): ModuleDict()\n",
      "              )\n",
      "              (down_proj): lora.Linear(\n",
      "                (base_layer): Linear(in_features=8960, out_features=1536, bias=False)\n",
      "                (lora_dropout): ModuleDict(\n",
      "                  (default): Identity()\n",
      "                )\n",
      "                (lora_A): ModuleDict(\n",
      "                  (default): Linear(in_features=8960, out_features=16, bias=False)\n",
      "                )\n",
      "                (lora_B): ModuleDict(\n",
      "                  (default): Linear(in_features=16, out_features=1536, bias=False)\n",
      "                )\n",
      "                (lora_embedding_A): ParameterDict()\n",
      "                (lora_embedding_B): ParameterDict()\n",
      "                (lora_magnitude_vector): ModuleDict()\n",
      "              )\n",
      "              (act_fn): SiLU()\n",
      "            )\n",
      "            (input_layernorm): Qwen2RMSNorm((1536,), eps=1e-06)\n",
      "            (post_attention_layernorm): Qwen2RMSNorm((1536,), eps=1e-06)\n",
      "          )\n",
      "          (27): Qwen2DecoderLayer(\n",
      "            (self_attn): Qwen2Attention(\n",
      "              (q_proj): lora.Linear4bit(\n",
      "                (base_layer): Linear4bit(in_features=1536, out_features=1536, bias=True)\n",
      "                (lora_dropout): ModuleDict(\n",
      "                  (default): Identity()\n",
      "                )\n",
      "                (lora_A): ModuleDict(\n",
      "                  (default): Linear(in_features=1536, out_features=16, bias=False)\n",
      "                )\n",
      "                (lora_B): ModuleDict(\n",
      "                  (default): Linear(in_features=16, out_features=1536, bias=False)\n",
      "                )\n",
      "                (lora_embedding_A): ParameterDict()\n",
      "                (lora_embedding_B): ParameterDict()\n",
      "                (lora_magnitude_vector): ModuleDict()\n",
      "              )\n",
      "              (k_proj): lora.Linear4bit(\n",
      "                (base_layer): Linear4bit(in_features=1536, out_features=256, bias=True)\n",
      "                (lora_dropout): ModuleDict(\n",
      "                  (default): Identity()\n",
      "                )\n",
      "                (lora_A): ModuleDict(\n",
      "                  (default): Linear(in_features=1536, out_features=16, bias=False)\n",
      "                )\n",
      "                (lora_B): ModuleDict(\n",
      "                  (default): Linear(in_features=16, out_features=256, bias=False)\n",
      "                )\n",
      "                (lora_embedding_A): ParameterDict()\n",
      "                (lora_embedding_B): ParameterDict()\n",
      "                (lora_magnitude_vector): ModuleDict()\n",
      "              )\n",
      "              (v_proj): lora.Linear4bit(\n",
      "                (base_layer): Linear4bit(in_features=1536, out_features=256, bias=True)\n",
      "                (lora_dropout): ModuleDict(\n",
      "                  (default): Identity()\n",
      "                )\n",
      "                (lora_A): ModuleDict(\n",
      "                  (default): Linear(in_features=1536, out_features=16, bias=False)\n",
      "                )\n",
      "                (lora_B): ModuleDict(\n",
      "                  (default): Linear(in_features=16, out_features=256, bias=False)\n",
      "                )\n",
      "                (lora_embedding_A): ParameterDict()\n",
      "                (lora_embedding_B): ParameterDict()\n",
      "                (lora_magnitude_vector): ModuleDict()\n",
      "              )\n",
      "              (o_proj): lora.Linear(\n",
      "                (base_layer): Linear(in_features=1536, out_features=1536, bias=False)\n",
      "                (lora_dropout): ModuleDict(\n",
      "                  (default): Identity()\n",
      "                )\n",
      "                (lora_A): ModuleDict(\n",
      "                  (default): Linear(in_features=1536, out_features=16, bias=False)\n",
      "                )\n",
      "                (lora_B): ModuleDict(\n",
      "                  (default): Linear(in_features=16, out_features=1536, bias=False)\n",
      "                )\n",
      "                (lora_embedding_A): ParameterDict()\n",
      "                (lora_embedding_B): ParameterDict()\n",
      "                (lora_magnitude_vector): ModuleDict()\n",
      "              )\n",
      "              (rotary_emb): LlamaRotaryEmbedding()\n",
      "            )\n",
      "            (mlp): Qwen2MLP(\n",
      "              (gate_proj): lora.Linear4bit(\n",
      "                (base_layer): Linear4bit(in_features=1536, out_features=8960, bias=False)\n",
      "                (lora_dropout): ModuleDict(\n",
      "                  (default): Identity()\n",
      "                )\n",
      "                (lora_A): ModuleDict(\n",
      "                  (default): Linear(in_features=1536, out_features=16, bias=False)\n",
      "                )\n",
      "                (lora_B): ModuleDict(\n",
      "                  (default): Linear(in_features=16, out_features=8960, bias=False)\n",
      "                )\n",
      "                (lora_embedding_A): ParameterDict()\n",
      "                (lora_embedding_B): ParameterDict()\n",
      "                (lora_magnitude_vector): ModuleDict()\n",
      "              )\n",
      "              (up_proj): lora.Linear4bit(\n",
      "                (base_layer): Linear4bit(in_features=1536, out_features=8960, bias=False)\n",
      "                (lora_dropout): ModuleDict(\n",
      "                  (default): Identity()\n",
      "                )\n",
      "                (lora_A): ModuleDict(\n",
      "                  (default): Linear(in_features=1536, out_features=16, bias=False)\n",
      "                )\n",
      "                (lora_B): ModuleDict(\n",
      "                  (default): Linear(in_features=16, out_features=8960, bias=False)\n",
      "                )\n",
      "                (lora_embedding_A): ParameterDict()\n",
      "                (lora_embedding_B): ParameterDict()\n",
      "                (lora_magnitude_vector): ModuleDict()\n",
      "              )\n",
      "              (down_proj): lora.Linear4bit(\n",
      "                (base_layer): Linear4bit(in_features=8960, out_features=1536, bias=False)\n",
      "                (lora_dropout): ModuleDict(\n",
      "                  (default): Identity()\n",
      "                )\n",
      "                (lora_A): ModuleDict(\n",
      "                  (default): Linear(in_features=8960, out_features=16, bias=False)\n",
      "                )\n",
      "                (lora_B): ModuleDict(\n",
      "                  (default): Linear(in_features=16, out_features=1536, bias=False)\n",
      "                )\n",
      "                (lora_embedding_A): ParameterDict()\n",
      "                (lora_embedding_B): ParameterDict()\n",
      "                (lora_magnitude_vector): ModuleDict()\n",
      "              )\n",
      "              (act_fn): SiLU()\n",
      "            )\n",
      "            (input_layernorm): Qwen2RMSNorm((1536,), eps=1e-06)\n",
      "            (post_attention_layernorm): Qwen2RMSNorm((1536,), eps=1e-06)\n",
      "          )\n",
      "        )\n",
      "        (norm): Qwen2RMSNorm((1536,), eps=1e-06)\n",
      "        (rotary_emb): LlamaRotaryEmbedding()\n",
      "      )\n",
      "      (lm_head): Linear(in_features=1536, out_features=151936, bias=False)\n",
      "    )\n",
      "  )\n",
      ")\n"
     ]
    }
   ],
   "source": [
    "# 因为 `model` 对象现在是由 Unsloth 创建的，它包含了所有必需的属性\n",
    "model = FastLanguageModel.get_peft_model(\n",
    "    model,\n",
    "    r=16,\n",
    "    target_modules=[\n",
    "      \"q_proj\",\n",
    "      \"k_proj\",\n",
    "      \"v_proj\",\n",
    "      \"o_proj\",\n",
    "      \"gate_proj\",\n",
    "      \"up_proj\",\n",
    "      \"down_proj\",\n",
    "    ],\n",
    "    lora_alpha=16,\n",
    "    lora_dropout=0,\n",
    "    bias=\"none\",\n",
    "    use_gradient_checkpointing=\"unsloth\",\n",
    "    random_state=1432,\n",
    "    use_rslora=False,\n",
    "    loftq_config=None,\n",
    ")\n",
    "# 检查模型结构，确认 LoRA 适配器已添加\n",
    "print(model)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 6. 配置 SFTTrainer\n",
    "\n",
    "`SFTTrainer` (Supervised Fine-tuning Trainer) 是一个专门用于指令微调的训练器。我们需要配置 `TrainingArguments` 来指定所有的训练参数，如批量大小、学习率、优化器等。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "74ca80dee65f42c491153afa8ad518f7",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Unsloth: Tokenizing [\"text\"] (num_proc=2):   0%|          | 0/20171 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "from trl import SFTConfig, SFTTrainer\n",
    "trainer = SFTTrainer(\n",
    "    model = model,\n",
    "    tokenizer = tokenizer,\n",
    "    train_dataset = dataset,\n",
    "    dataset_text_field = \"text\",\n",
    "    max_seq_length = max_seq_length,\n",
    "    packing = False, # Can make training 5x faster for short sequences.\n",
    "    args = SFTConfig(\n",
    "        per_device_train_batch_size = 64,\n",
    "        gradient_accumulation_steps = 2,\n",
    "        warmup_steps = 5,\n",
    "        # num_train_epochs = 1, # Set this for 1 full training run.\n",
    "        max_steps = 60,\n",
    "        learning_rate = 2e-4,\n",
    "        logging_steps = 1,\n",
    "        optim = \"adamw_8bit\",\n",
    "        weight_decay = 0.01,\n",
    "        lr_scheduler_type = \"linear\",\n",
    "        seed = 1432,\n",
    "        output_dir = \"outputs\",\n",
    "        report_to = \"none\", # Use this for WandB etc\n",
    "    ),\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 7. 开始训练\n",
    "\n",
    "一切准备就绪后，调用 `trainer.train()` 即可开始微调过程。训练结束后，会返回包含训练统计信息（如训练损失）的对象。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "==((====))==  Unsloth - 2x faster free finetuning | Num GPUs used = 1\n",
      "   \\\\   /|    Num examples = 20,171 | Num Epochs = 1 | Total steps = 60\n",
      "O^O/ \\_/ \\    Batch size per device = 64 | Gradient accumulation steps = 2\n",
      "\\        /    Data Parallel GPUs = 1 | Total batch size (64 x 2 x 1) = 128\n",
      " \"-____-\"     Trainable parameters = 18,464,768 of 1,795,552,768 (1.03% trained)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Unsloth: Will smartly offload gradients to save VRAM!\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='60' max='60' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [60/60 1:11:13, Epoch 0/1]\n",
       "    </div>\n",
       "    <table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       " <tr style=\"text-align: left;\">\n",
       "      <th>Step</th>\n",
       "      <th>Training Loss</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>1</td>\n",
       "      <td>3.171100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2</td>\n",
       "      <td>3.177300</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3</td>\n",
       "      <td>3.084400</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4</td>\n",
       "      <td>3.150200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>5</td>\n",
       "      <td>3.110500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>6</td>\n",
       "      <td>2.961400</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>7</td>\n",
       "      <td>2.947300</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>8</td>\n",
       "      <td>2.919600</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>9</td>\n",
       "      <td>2.883200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>10</td>\n",
       "      <td>2.780500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>11</td>\n",
       "      <td>2.783200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>12</td>\n",
       "      <td>2.690500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>13</td>\n",
       "      <td>2.676900</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>14</td>\n",
       "      <td>2.603300</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>15</td>\n",
       "      <td>2.637200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>16</td>\n",
       "      <td>2.586800</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>17</td>\n",
       "      <td>2.505500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>18</td>\n",
       "      <td>2.496300</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>19</td>\n",
       "      <td>2.429800</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>20</td>\n",
       "      <td>2.431900</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>21</td>\n",
       "      <td>2.444400</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>22</td>\n",
       "      <td>2.390300</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>23</td>\n",
       "      <td>2.380000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>24</td>\n",
       "      <td>2.353500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>25</td>\n",
       "      <td>2.368300</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>26</td>\n",
       "      <td>2.327700</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>27</td>\n",
       "      <td>2.302600</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>28</td>\n",
       "      <td>2.321900</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>29</td>\n",
       "      <td>2.318700</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>30</td>\n",
       "      <td>2.313800</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>31</td>\n",
       "      <td>2.302800</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>32</td>\n",
       "      <td>2.361800</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>33</td>\n",
       "      <td>2.280600</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>34</td>\n",
       "      <td>2.285300</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>35</td>\n",
       "      <td>2.313500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>36</td>\n",
       "      <td>2.270800</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>37</td>\n",
       "      <td>2.310100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>38</td>\n",
       "      <td>2.307400</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>39</td>\n",
       "      <td>2.292900</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>40</td>\n",
       "      <td>2.229100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>41</td>\n",
       "      <td>2.280200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>42</td>\n",
       "      <td>2.273100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>43</td>\n",
       "      <td>2.259200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>44</td>\n",
       "      <td>2.236300</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>45</td>\n",
       "      <td>2.253000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>46</td>\n",
       "      <td>2.242300</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>47</td>\n",
       "      <td>2.269500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>48</td>\n",
       "      <td>2.242800</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>49</td>\n",
       "      <td>2.291300</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>50</td>\n",
       "      <td>2.262700</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>51</td>\n",
       "      <td>2.236600</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>52</td>\n",
       "      <td>2.291700</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>53</td>\n",
       "      <td>2.272200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>54</td>\n",
       "      <td>2.224700</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>55</td>\n",
       "      <td>2.251400</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>56</td>\n",
       "      <td>2.282000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>57</td>\n",
       "      <td>2.223500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>58</td>\n",
       "      <td>2.226300</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>59</td>\n",
       "      <td>2.218400</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>60</td>\n",
       "      <td>2.229800</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table><p>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "TrainOutput(global_step=60, training_loss=2.4594939827919005, metrics={'train_runtime': 4346.4066, 'train_samples_per_second': 1.767, 'train_steps_per_second': 0.014, 'total_flos': 7.19252239613952e+16, 'train_loss': 2.4594939827919005})\n"
     ]
    }
   ],
   "source": [
    "trainer_stats = trainer.train()\n",
    "\n",
    "# 打印训练统计信息\n",
    "print(trainer_stats)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 8. 保存微调后的模型（Lora）\n",
    "\n",
    "训练完成后，您可以再次进行推理，比较微调后的模型与原始模型的差异。如果对结果满意，可以使用 `model.save_pretrained(\"your_lora_adapter_path\")` 来保存训练好的 LoRA 适配器。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.save_pretrained(\"qwen-1.5b_lora_model\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "('qwen-1.5b_lora_model/tokenizer_config.json',\n",
       " 'qwen-1.5b_lora_model/special_tokens_map.json',\n",
       " 'qwen-1.5b_lora_model/chat_template.jinja',\n",
       " 'qwen-1.5b_lora_model/tokenizer.json')"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tokenizer.save_pretrained(\"qwen-1.5b_lora_model\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 模型保存方式二选一（要么使用上面的分开保存，要么使用这里的合并 Lora 保存）\n",
    "# model.save_pretrained_merged(\"qwen-1.5b_lora_model\", tokenizer, save_method=\"merged_16bit\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 9. 测试训练后的生成结果"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {},
   "outputs": [],
   "source": [
    "FastLanguageModel.for_inference(model) # Enable native 2x faster inference\n",
    "\n",
    "question=\"一个患有急性阑尾炎的病人已经发病5天，腹痛稍有减轻但仍然发热，在体检时发现右下腹有压痛的包块，此时应如何处理？\", # Question\n",
    "inputs = tokenizer([inference_prompt.format(question, \"\")], return_tensors=\"pt\").to(\"cuda\")\n",
    "\n",
    "outputs = model.generate(\n",
    "    input_ids=inputs.input_ids,\n",
    "    attention_mask=inputs.attention_mask,\n",
    "    max_new_tokens=1000,\n",
    "    use_cache=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "<think>\n",
      "这个病人已经5天了，急性阑尾炎，感觉腹痛只是稍微减轻了一点，但是还是有点发热。所以，他的疼痛和发热情况看起来还是挺正常的。\n",
      "\n",
      "根据他的症状，我觉得他应该还是有办法去处理的。就是说，虽然疼痛和发热看起来有点轻，但不要过于担心，毕竟阑尾炎确实不是那么容易控制的。\n",
      "\n",
      "嗯，想到这里，我想到要找一个合适的治疗方法，就是把疼痛和发热都控制住，然后看看他的疼痛和发热情况是否还有其他变化。\n",
      "\n",
      "可能，这个病人需要进行一些排痛和抗炎治疗，但前提是他的疼痛和发热症状没有太大的改善。如果症状没有改善，那么可能需要考虑更复杂的治疗方法，比如手术。\n",
      "\n",
      "不过，从他的症状来看，疼痛和发热已经有所减轻，而且在体检时还发现了右下腹的压痛包块，这让我想到，可能可以考虑使用一些药物来缓解疼痛和发热。\n",
      "\n",
      "不过，如果疼痛和发热症状没有改善，或者症状变得越来越重，那么可能需要手术来处理这种情况。\n",
      "\n",
      "所以，综合来看，这个病人在5天的治疗时间里，应该根据症状的具体情况来决定采取什么治疗。\n",
      "</think>\n",
      "根据病人5天的治疗时间，症状有所减轻但发热和疼痛并未改善，因此需要考虑进一步的治疗。在这种情况下，建议使用非侵入性疼痛治疗，如止痛药，以缓解疼痛症状。同时，使用抗炎药物如阿司匹林或阿立克司他来减轻发热。如果症状没有改善，建议进行手术以处理压痛包块。\n"
     ]
    }
   ],
   "source": [
    "output = tokenizer.batch_decode(outputs, skip_special_tokens=True)\n",
    "print(output[0].split(\"### Response:\")[1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate_response(question: str, model, tokenizer, inference_prompt: str, max_new_tokens: int = 1024) -> str:\n",
    "    \"\"\"\n",
    "    使用指定的模型和分词器为给定的医学问题生成响应。\n",
    "\n",
    "    Args:\n",
    "        question (str): 需要模型回答的医学问题。\n",
    "        model: 已加载的 Unsloth/Hugging Face 模型。\n",
    "        tokenizer: 对应的分词器。\n",
    "        inference_prompt (str): 用于格式化输入的 f-string 模板。\n",
    "        max_new_tokens (int, optional): 生成响应的最大 token 数量。默认为 1024。\n",
    "\n",
    "    Returns:\n",
    "        str: 模型生成的响应文本，已去除 prompt 部分。\n",
    "    \"\"\"\n",
    "    # 1. 使用模板格式化输入\n",
    "    prompt = inference_prompt.format(\n",
    "        question, # 填充问题\n",
    "        \"\",       # 留空，让模型生成 CoT 和 Response\n",
    "    )\n",
    "\n",
    "    # 2. 将格式化后的 prompt 进行分词，并转移到 GPU\n",
    "    inputs = tokenizer([prompt], return_tensors=\"pt\").to(model.device)\n",
    "\n",
    "    # 3. 使用模型生成输出\n",
    "    # use_cache=True 用于加速解码过程\n",
    "    outputs = model.generate(\n",
    "        input_ids=inputs.input_ids,\n",
    "        attention_mask=inputs.attention_mask,\n",
    "        max_new_tokens=max_new_tokens,\n",
    "        use_cache=True,\n",
    "    )\n",
    "    \n",
    "    # 4. 将生成的 token 解码为文本\n",
    "    # skip_special_tokens=True 会移除像 EOS_TOKEN 这样的特殊标记\n",
    "    decoded_output = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]\n",
    "\n",
    "    # 5. 切分字符串，只返回 \"### Response:\" 之后的部分\n",
    "    # 使用 .split() 分割并获取响应内容，.strip() 用于去除可能存在的前后空白字符\n",
    "    response_part = decoded_output.split(\"### Response:\")\n",
    "    if len(response_part) > 1:\n",
    "        return response_part[1].strip()\n",
    "    else:\n",
    "        # 如果模型没有生成 \"### Response:\" 标记，则返回整个生成内容以供调试\n",
    "        return decoded_output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "==================== 模型回答 ====================\n",
      "<think>\n",
      "嗯，这位60岁的男性患者有右侧胸疼，X线检查显示右侧肋膈角消失，这让我想到可能有肺结核的迹象。肺结核通常会导致胸腔积液，特别是右侧的。要弄清楚这种积液的具体性质，实验室检查应该是关键。\n",
      "\n",
      "首先，检查胸腔液体的酸碱度，能告诉我这个积液是酸性还是碱性。酸性或碱性可以帮助我们确定是什么类型的液体，比如乳酸水或者乳酸钠。如果怀疑是乳酸水，可以用乳酸水酶的检测方法来进一步确认。\n",
      "\n",
      "然后，看看是否有乳酸钠的存在，这在肺结核患者中非常重要，因为乳酸钠可以释放乳酸，帮助维持酸性环境。如果检测到乳酸钠，就能进一步验证积液的类型。\n",
      "\n",
      "此外，检测尿液中的乳酸含量也很重要，因为这能直接反映肺结核患者的乳酸含量，从而判断是否已经产生了乳酸水。\n",
      "\n",
      "最后，用血清中的血钙含量来判断血钙水平，因为血钙升高可能提示肺结核患者的血钙水平较高，也可能暗示了乳酸含量增加。\n",
      "\n",
      "综上所述，通过这些检查，我们能更准确地了解胸腔积液的具体性质和类型，从而做出更准确的诊断。\n",
      "</think>\n",
      "```\n",
      "在诊断肺结核伴右侧胸腔积液时，了解胸腔液体的酸碱度和乳酸含量是关键。以下是这些检查的步骤：\n",
      "\n",
      "1. **检查胸腔液体的酸碱度**：通过酸碱度计检测，可以确定胸腔液体是酸性还是碱性。酸性或碱性可以帮助我们初步判断是乳酸水或乳酸钠。\n",
      "\n",
      "2. **检测乳酸钠的存在**：如果怀疑是乳酸水，可以用乳酸水酶来检测，以确认乳酸钠的存在。\n",
      "\n",
      "3. **尿液中的乳酸检测**：尿液中的乳酸含量可以反映肺结核患者的乳酸含量，从而判断是否已开始产生乳酸水。\n",
      "\n",
      "4. **血清中的血钙检测**：血钙水平的变化可以提示血钙的水平，从而判断是否为肺结核患者的血钙水平较高。\n",
      "\n",
      "通过这些检查，我们能够更准确地了解胸腔积液的具体性质和类型，从而做出更准确的诊断。\n"
     ]
    }
   ],
   "source": [
    "my_question = \"对于一名60岁男性患者，出现右侧胸疼并在X线检查中显示右侧肋膈角消失，诊断为肺结核伴右侧胸腔积液，请问哪一项实验室检查对了解胸水的性质更有帮助？\"\n",
    "\n",
    "response = generate_response(my_question, model, tokenizer, inference_prompt)\n",
    "print(\"==================== 模型回答 ====================\")\n",
    "print(response)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "==================== 模型回答 ====================\n",
      "<think>\n",
      "这位28岁的男性患者，工作是程序员，长期熬夜，最近突然感觉头晕目眩，甚至有点恶心，这让我想到他可能得了脑损伤性晕眩症。\n",
      "\n",
      "脑损伤性晕眩症是一种常见的神经系统疾病，尤其是脑损伤后，导致脑缺氧和缺血，从而引起晕眩等症状。这种疾病通常表现为脑缺氧，进而导致晕眩，尤其在长时间的脑损伤后，晕眩症状会更加明显。\n",
      "\n",
      "再仔细想想，这位患者长时间熬夜，这可能是因为脑缺氧导致的疲劳，甚至更可能导致脑损伤性晕眩。这种情况下，晕眩不仅仅是简单的疲劳，而是与脑缺氧和缺血直接相关。\n",
      "\n",
      "脑损伤性晕眩症的特征包括脑缺氧、缺血、缺氧血流量和脑缺氧血流量，这些都是导致晕眩的重要原因。这种疾病也常伴随脑水肿和脑损伤，所以患者的情况可能和脑水肿有关。\n",
      "\n",
      "不过，脑损伤性晕眩症的病因其实不是单一的，它不仅受脑损伤的影响，还可能受到脑损伤前脑水肿的影响。这种情况下，患者可能需要考虑脑水肿作为可能的病因。\n",
      "\n",
      "总的来说，这位患者可能得了脑损伤性晕眩症，因为他的症状和病因都符合脑损伤后出现的典型表现。不过，由于病因复杂，建议他尽快去看脑水肿专家，以排除脑水肿的可能，从而更好地诊断和治疗。\n",
      "</think>\n",
      "```\n",
      "根据患者的症状和病因分析，这位28岁的男性患者，工作是程序员，长期熬夜，最近突然感觉头晕目眩，甚至有点恶心，这提示他可能得了脑损伤性晕眩症。脑损伤性晕眩症是一种常见的神经系统疾病，尤其是脑损伤后，导致脑缺氧和缺血，从而引起晕眩。这种疾病通常表现为脑缺氧，进而导致晕眩，尤其在长时间的脑损伤后，晕眩症状会更加明显。\n",
      "\n",
      "此外，脑损伤性晕眩症的病因也受脑损伤前脑水肿的影响，这种情况下，患者可能需要考虑脑水肿作为可能的病因。因此，建议尽快查看脑水肿专家，以排除脑水肿的可能，从而更好地诊断和治疗。\n",
      "\n",
      "综上所述，这位患者可能得了脑损伤性晕眩症，但建议尽快寻求专业诊断，以确保其病情得到及时准确的治疗。\n"
     ]
    }
   ],
   "source": [
    "my_question = \"对于一名 28 岁的男性患者，工作是程序员，常年熬夜，最近突然感觉头晕目眩，甚至有点恶心。请问有可能是什么疾病？\"\n",
    "\n",
    "response = generate_response(my_question, model, tokenizer, inference_prompt)\n",
    "print(\"==================== 模型回答 ====================\")\n",
    "print(response)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
